{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quadrotor Example\n",
    "This notebook will demonstrate how to set up and solve a trajectory optimization problem for a quadrotor. In particular, it will highlight how TrajectoryOptimization.jl accounts for the group structure of 3D rotations.\n",
    "\n",
    "### Loading the Required Packages\n",
    "To define the quadrotor model, we import `RobotDynamics` and `Rotations`, and use `TrajectoryOptimization` to define the problem. We load in `StaticArrays` and `LinearAlgebra` to help with the setup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m environment at `~/.julia/dev/TrajectoryOptimization/examples/Project.toml`\n"
     ]
    }
   ],
   "source": [
    "# import Pkg; Pkg.activate(@__DIR__); Pkg.instantiate();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Precompiling RobotDynamics [38ceca67-d8d3-44e8-9852-78a5596522e1]\n",
      "└ @ Base loading.jl:1260\n",
      "┌ Info: Precompiling TrajectoryOptimization [c79d492b-0548-5874-b488-5a62c1d9d0ca]\n",
      "└ @ Base loading.jl:1260\n"
     ]
    }
   ],
   "source": [
    "using RobotDynamics, Rotations\n",
    "using TrajectoryOptimization\n",
    "using StaticArrays, LinearAlgebra"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the model\n",
    "We could use the quadrotor model defined in `RobotZoo.jl`, but instead we'll go through the details of using the `RigidBody` interface in `RobotDynamics`.\n",
    "\n",
    "We start by defining our new `Quadrotor` type, which is a subtype of `RigidBody{R}`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "struct Quadrotor{R} <: RigidBody{R}\n",
    "    n::Int\n",
    "    m::Int\n",
    "    mass::Float64\n",
    "    J::Diagonal{Float64,SVector{3,Float64}}\n",
    "    Jinv::Diagonal{Float64,SVector{3,Float64}}\n",
    "    gravity::SVector{3,Float64}\n",
    "    motor_dist::Float64\n",
    "    kf::Float64\n",
    "    km::Float64\n",
    "    bodyframe::Bool  # velocity in body frame?\n",
    "    info::Dict{Symbol,Any}\n",
    "end\n",
    "\n",
    "function Quadrotor{R}(;\n",
    "        mass=0.5,\n",
    "        J=Diagonal(@SVector [0.0023, 0.0023, 0.004]),\n",
    "        gravity=SVector(0,0,-9.81),\n",
    "        motor_dist=0.1750,\n",
    "        kf=1.0,\n",
    "        km=0.0245,\n",
    "        bodyframe=false,\n",
    "        info=Dict{Symbol,Any}()) where R\n",
    "    Quadrotor{R}(13,4,mass,J,inv(J),gravity,motor_dist,kf,km,bodyframe,info)\n",
    "end\n",
    "\n",
    "(::Type{Quadrotor})(;kwargs...) = Quadrotor{UnitQuaternion{Float64}}(;kwargs...)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "where `R` is the rotation parameterization being used, typically one of `UnitQuaternion{T}`, `MRP{T}`, or `RodriguesParam{T}`. \n",
    "\n",
    "We now need to define the number of control inputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "RobotDynamics.control_dim(::Quadrotor) = 4"
   ]
  },
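  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a brief aside on why the rotation parameterization `R` matters: unit quaternions compose by multiplication rather than addition, so the difference between two orientations is itself a rotation. The sketch below illustrates this with a hand-written Hamilton product on plain vectors stored as `(w, x, y, z)`. The helper names `qmul`, `qconj`, `qerr`, and `qz` are hypothetical and are not part of `Rotations` or `RobotDynamics`.\n",
    "\n",
    "```julia\n",
    "using LinearAlgebra\n",
    "\n",
    "# Hamilton product of two quaternions stored as (w, x, y, z)\n",
    "function qmul(a, b)\n",
    "    w1, x1, y1, z1 = a\n",
    "    w2, x2, y2, z2 = b\n",
    "    return [w1*w2 - x1*x2 - y1*y2 - z1*z2,\n",
    "            w1*x2 + x1*w2 + y1*z2 - z1*y2,\n",
    "            w1*y2 - x1*z2 + y1*w2 + z1*x2,\n",
    "            w1*z2 + x1*y2 - y1*x2 + z1*w2]\n",
    "end\n",
    "\n",
    "# Conjugate, which is the inverse for a unit quaternion\n",
    "qconj(q) = [q[1], -q[2], -q[3], -q[4]]\n",
    "\n",
    "# Rotation that takes orientation q1 to orientation q2\n",
    "qerr(q1, q2) = qmul(qconj(q1), q2)\n",
    "\n",
    "# Unit quaternion for a rotation of θ radians about the z-axis\n",
    "qz(θ) = [cos(θ/2), 0.0, 0.0, sin(θ/2)]\n",
    "\n",
    "q1, q2 = qz(deg2rad(30)), qz(deg2rad(90))\n",
    "e = qerr(q1, q2)                          # a 60-degree rotation about z\n",
    "err_angle = 2*atan(norm(e[2:4]), e[1])    # rotation angle of the error, in radians\n",
    "naive = q2 - q1                           # component-wise difference\n",
    "```\n",
    "\n",
    "Here `norm(e)` stays at 1, so the error is still a valid rotation, while `norm(naive)` does not, which is why subtracting quaternions component-wise does not produce a meaningful orientation error."
   ]
  },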
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to define the dynamics of our quadrotor, which we do by simply defining the forces and moments acting on our quadrotor for a given state and control, as well as some \"getter\" methods for our inertial properties.\n",
    "\n",
    "It's important to note that the force is in the world frame, and torque is in the body frame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "function RobotDynamics.forces(model::Quadrotor, x, u)\n",
    "    q = orientation(model, x)\n",
    "    kf = model.kf\n",
    "    g = model.gravity\n",
    "    m = model.mass\n",
    "\n",
    "    # Extract motor speeds\n",
    "    w1 = u[1]\n",
    "    w2 = u[2]\n",
    "    w3 = u[3]\n",
    "    w4 = u[4]\n",
    "\n",
    "    # Calculate motor forces\n",
    "    F1 = max(0,kf*w1);\n",
    "    F2 = max(0,kf*w2);\n",
    "    F3 = max(0,kf*w3);\n",
    "    F4 = max(0,kf*w4);\n",
    "    F = @SVector [0., 0., F1+F2+F3+F4] #total rotor force in body frame\n",
    "\n",
    "    m*g + q*F # forces in world frame\n",
    "end\n",
    "\n",
    "function RobotDynamics.moments(model::Quadrotor, x, u)\n",
    "\n",
    "    kf, km = model.kf, model.km\n",
    "    L = model.motor_dist\n",
    "\n",
    "    # Extract motor speeds\n",
    "    w1 = u[1]\n",
    "    w2 = u[2]\n",
    "    w3 = u[3]\n",
    "    w4 = u[4]\n",
    "    \n",
    "    # Calculate motor forces\n",
    "    F1 = max(0,kf*w1);\n",
    "    F2 = max(0,kf*w2);\n",
    "    F3 = max(0,kf*w3);\n",
    "    F4 = max(0,kf*w4);\n",
    "\n",
    "    # Calculate motor torques\n",
    "    M1 = km*w1;\n",
    "    M2 = km*w2;\n",
    "    M3 = km*w3;\n",
    "    M4 = km*w4;\n",
    "    tau = @SVector [L*(F2-F4), L*(F3-F1), (M1-M2+M3-M4)] #total rotor torque in body frame\n",
    "end\n",
    "\n",
    "RobotDynamics.inertia(model::Quadrotor) = model.J\n",
    "RobotDynamics.inertia_inv(model::Quadrotor) = model.Jinv\n",
    "RobotDynamics.mass(model::Quadrotor) = model.mass"
   ]
  },
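  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on the force and moment definitions: at the identity orientation, the net force vanishes when the four rotor forces sum to the vehicle's weight, and equal rotor inputs produce zero net torque. The sketch below reproduces that arithmetic with plain arrays and the default parameters from the constructor. The names `hover_input`, `net_force`, and `net_torque` are hypothetical helpers, and the orientation is assumed to be the identity so that no quaternion product is needed.\n",
    "\n",
    "```julia\n",
    "# Default parameters from the Quadrotor constructor\n",
    "mass_kg, kf, km, L = 0.5, 1.0, 0.0245, 0.1750\n",
    "gravity = [0.0, 0.0, -9.81]\n",
    "\n",
    "# Per-motor input at which the total thrust balances the weight\n",
    "hover_input = -mass_kg*gravity[3]/(4*kf)\n",
    "\n",
    "# Net world-frame force at the identity orientation (body z = world z)\n",
    "net_force(u) = mass_kg*gravity + [0.0, 0.0, sum(max.(0, kf .* u))]\n",
    "\n",
    "# Net body-frame torque, using the same mixing as `moments`\n",
    "function net_torque(u)\n",
    "    F = max.(0, kf .* u)\n",
    "    M = km .* u\n",
    "    return [L*(F[2] - F[4]), L*(F[3] - F[1]), M[1] - M[2] + M[3] - M[4]]\n",
    "end\n",
    "\n",
    "u_hover = fill(hover_input, 4)\n",
    "net_force(u_hover)     # approximately zero in every component\n",
    "net_torque(u_hover)    # zero, since all four inputs are equal\n",
    "```"
   ]
  },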
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And with that our model is defined!\n",
    "\n",
    "## Setting up our problem\n",
    "For our trajectory optimization problem, we're going to have the quadrotor do a \"zig-zag\" pattern. We can do this via objective/cost function manipulation. We start by creating our quadrotor model and defining our integration scheme:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.05"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set up model and discretization\n",
    "model = Quadrotor();\n",
    "n,m = size(model)\n",
    "N = 101                # number of knot points\n",
    "tf = 5.0               # total time (sec)\n",
    "dt = tf/(N-1)          # time step (sec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to set up the initial and final conditions for our quadrotor, which we want to move 20 meters in the y-direction. We can build the state piece-by-piece using the `RobotDynamics.build_state` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "x0_pos = SA[0, -10, 1.]\n",
    "xf_pos = SA[0, +10, 1.]\n",
    "x0 = RobotDynamics.build_state(model, x0_pos, UnitQuaternion(I), zeros(3), zeros(3))\n",
    "xf = RobotDynamics.build_state(model, xf_pos, UnitQuaternion(I), zeros(3), zeros(3));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the cost function\n",
    "We now create a cost function that encourages a \"zig-zag\" pattern for the quadrotor. We set up a few waypoints at specific times, and impose a high cost for being far from those locations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up waypoints\n",
    "wpts = [SA[+10, 0, 1.],\n",
    "        SA[-10, 0, 1.],\n",
    "        xf_pos]\n",
    "times = [33, 66, 101]   # in knot points\n",
    "\n",
    "# Set up nominal costs\n",
    "Q = Diagonal(RobotDynamics.fill_state(model, 1e-5, 1e-5, 1e-3, 1e-3))\n",
    "R = Diagonal(@SVector fill(1e-4, 4))\n",
    "q_nom = UnitQuaternion(I)\n",
    "v_nom = zeros(3)\n",
    "ω_nom = zeros(3)\n",
    "x_nom = RobotDynamics.build_state(model, zeros(3), q_nom, v_nom, ω_nom)\n",
    "cost_nom = LQRCost(Q, R, x_nom)\n",
    "\n",
    "# Set up waypoint costs\n",
    "Qw_diag = RobotDynamics.fill_state(model, 1e3,1,1,1)\n",
    "Qf_diag = RobotDynamics.fill_state(model, 10., 100, 10, 10)\n",
    "costs = map(1:length(wpts)) do i\n",
    "    r = wpts[i]\n",
    "    xg = RobotDynamics.build_state(model, r, q_nom, v_nom, ω_nom)\n",
    "    if times[i] == N\n",
    "        Q = Diagonal(Qf_diag)\n",
    "    else\n",
    "        Q = Diagonal(1e-3*Qw_diag)\n",
    "    end\n",
    "\n",
    "    LQRCost(Q, R, xg)\n",
    "end\n",
    "\n",
    "# Build Objective\n",
    "costs_all = map(1:N) do k\n",
    "    i = findfirst(x->(x ≥ k), times)\n",
    "    if k ∈ times\n",
    "        costs[i]\n",
    "    else\n",
    "        cost_nom\n",
    "    end\n",
    "end\n",
    "obj = Objective(costs_all);"
   ]
  },
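  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see which knot points receive a waypoint cost and when they occur, the schedule can be reproduced with plain arrays. This is a small sketch that only mirrors the indexing logic of the objective. The names `waypoint_index` and `waypoint_time` are hypothetical, and knot point `k` is assumed to sit at time `(k-1)*dt`.\n",
    "\n",
    "```julia\n",
    "N, tf = 101, 5.0\n",
    "dt = tf/(N - 1)\n",
    "times = [33, 66, 101]          # waypoint knot points\n",
    "\n",
    "# Index of the waypoint active at knot point k, or 0 for the nominal cost\n",
    "waypoint_index = [something(findfirst(==(k), times), 0) for k in 1:N]\n",
    "\n",
    "# Times (sec) at which the waypoint costs apply\n",
    "waypoint_time = (times .- 1) .* dt\n",
    "\n",
    "count(!iszero, waypoint_index)   # 3 knot points use a waypoint cost\n",
    "waypoint_index[33]               # 1, the first waypoint\n",
    "waypoint_index[34]               # 0, the nominal cost\n",
    "```\n",
    "\n",
    "All other knot points fall back to the nominal cost, so the waypoints shape the trajectory only at roughly 1.6, 3.25, and 5.0 seconds."
   ]
  },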
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization\n",
    "We initialize the solver with a simple hover trajectory that keeps the quadrotor hovering at the initial position."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "u0 = @SVector fill(-model.gravity[3]*model.mass/(m*model.kf), m)  # per-motor input whose total thrust balances gravity\n",
    "U_hover = [copy(u0) for k = 1:N-1]; # initial hovering control trajectory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constraints\n",
    "For this problem, we only impose bounds on the controls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "conSet = ConstraintList(n,m,N)\n",
    "bnd = BoundConstraint(n,m, u_min=0.0, u_max=12.0)\n",
    "add_constraint!(conSet, bnd, 1:N-1)"
   ]
  },
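  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bound constraint requires each control to lie in the interval `[0, 12]`. As a rough illustration of how a violation of such a box constraint can be measured, the sketch below computes the largest amount by which a control vector leaves the box. The function `bound_violation` is hypothetical and is not how `TrajectoryOptimization` evaluates `BoundConstraint` internally.\n",
    "\n",
    "```julia\n",
    "# Largest violation of u_min .<= u .<= u_max (zero when u is feasible)\n",
    "function bound_violation(u, u_min, u_max)\n",
    "    upper = maximum(u .- u_max)   # positive if any entry exceeds u_max\n",
    "    lower = maximum(u_min .- u)   # positive if any entry is below u_min\n",
    "    return max(upper, lower, 0.0)\n",
    "end\n",
    "\n",
    "bound_violation([1.0, 2.0, 3.0, 4.0], 0.0, 12.0)     # 0.0, feasible\n",
    "bound_violation([1.0, 13.5, 3.0, -0.5], 0.0, 12.0)   # 1.5, from the second entry\n",
    "```"
   ]
  },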
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building the Problem\n",
    "We now build the trajectory optimization problem, providing a dynamically-feasible initialization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "prob = Problem(model, obj, xf, tf, x0=x0, constraints=conSet)\n",
    "initial_controls!(prob, U_hover)\n",
    "rollout!(prob);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solving the Problem using ALTRO\n",
    "With our problem set up, we can solve it using any of the supported solvers. We'll use ALTRO:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Precompiling Altro [5dcf52e5-e2fb-48e0-b826-96f46d2e3e73]\n",
      "└ @ Base loading.jl:1260\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cost: 0.2992834848449584\n",
      "Constraint violation: 7.598522921981044e-10\n",
      "Iterations: 90\n"
     ]
    }
   ],
   "source": [
    "using Altro\n",
    "opts = SolverOptions(\n",
    "    penalty_scaling=100.,\n",
    "    penalty_initial=0.1,\n",
    ")\n",
    "\n",
    "solver = ALTROSolver(prob, opts);\n",
    "solve!(solver)\n",
    "println(\"Cost: \", cost(solver))\n",
    "println(\"Constraint violation: \", max_violation(solver))\n",
    "println(\"Iterations: \", iterations(solver))"
   ]
  },
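  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give a feel for the two solver options set above, the sketch below tabulates how a penalty parameter would grow if it started at `penalty_initial` and were multiplied by `penalty_scaling` on each outer-loop update. This assumes a purely geometric schedule with no upper limit, which is a simplification for illustration rather than a description of ALTRO's exact update rule. The function `penalty` is a hypothetical helper.\n",
    "\n",
    "```julia\n",
    "penalty_initial, penalty_scaling = 0.1, 100.0\n",
    "\n",
    "# Penalty after j outer-loop updates under the assumed geometric schedule\n",
    "penalty(j) = penalty_initial*penalty_scaling^j\n",
    "\n",
    "schedule = [penalty(j) for j in 0:3]   # roughly 0.1, 10, 1e3, 1e5\n",
    "```\n",
    "\n",
    "A large scaling factor such as 100 raises the penalty by several orders of magnitude within a few updates, which pushes the iterates toward feasibility quickly at the cost of worse conditioning."
   ]
  },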
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing the solution\n",
    "We can use `TrajOptPlots` to visualize the solution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Precompiling TrajOptPlots [7770976a-8dee-4930-bf39-a1782fd21ce6]\n",
      "└ @ Base loading.jl:1260\n",
      "┌ Info: MeshCat server started. You can open the visualizer by visiting the following URL in your browser:\n",
      "│ http://localhost:8701\n",
      "└ @ MeshCat /home/bjack205/.julia/packages/MeshCat/ECbzr/src/visualizer.jl:73\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "    <div style=\"height: 500px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both\">\n",
       "    <iframe src=\"http://localhost:8701\" style=\"width: 100%; height: 100%; border: none\"></iframe>\n",
       "    </div>\n"
      ],
      "text/plain": [
       "MeshCat.DisplayedVisualizer(MeshCat.CoreVisualizer(MeshCat.SceneTrees.SceneNode(nothing, nothing, Dict{String,Array{UInt8,1}}(), nothing, Dict{String,MeshCat.SceneTrees.SceneNode}()), Set(Any[]), ip\"127.0.0.1\", 8701))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using TrajOptPlots\n",
    "using MeshCat\n",
    "using Plots\n",
    "\n",
    "vis = Visualizer()\n",
    "render(vis)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the visualization, we use `FileIO` and `MeshIO v0.3` to load a mesh file. We also need to tell `TrajOptPlots` what geometry to display, which we do by defining the `_set_mesh!` method for our model. Since our model is a `RigidBody`, `TrajOptPlots` already knows how to display it once the robot geometry is defined."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MeshCat Visualizer with path /meshcat/robot/geom at http://localhost:8701"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using FileIO, MeshIO\n",
    "function TrajOptPlots._set_mesh!(vis, model::Quadrotor)\n",
    "    obj = joinpath(@__DIR__, \"quadrotor.obj\")\n",
    "    quad_scaling = 0.085\n",
    "    robot_obj = FileIO.load(obj)\n",
    "    robot_obj.vertices .*= quad_scaling\n",
    "    mat = MeshPhongMaterial(color=colorant\"black\")\n",
    "    setobject!(vis[\"geom\"], robot_obj, mat)\n",
    "end\n",
    "TrajOptPlots.set_mesh!(vis, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize!(vis, solver);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.4.2",
   "language": "julia",
   "name": "julia-1.4"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.4.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
